#!/usr/bin/env python

# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import sys
import time
import logging

import omegaconf
import hydra

import pandas as pd
import dash
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output
import plotly
import plotly.subplots

import queue
import threading
import collections

# Make the Python modules generated by CMake (polymetis_pb2, polymetis_pb2_grpc)
# importable by appending their build directory to the module search path.
package_name = "polymetis"

current_dir = os.path.split(os.path.realpath(__file__))[0]
catkin_build_path = os.path.abspath(
    os.path.join(current_dir, f"../../../build/{package_name}")
)
sys.path.extend([catkin_build_path])

from google.protobuf.internal.encoder import _VarintBytes
from google.protobuf.internal.decoder import _DecodeVarint32

import grpc
import polymetis_pb2
import polymetis_pb2_grpc


def write_protobuf(file, message):
    """Append one length-prefixed protobuf record to ``file``.

    The record is the message's serialized size encoded as a varint, followed
    by the serialized message bytes.

    Args:
        file: a writable binary file object.
        message: the protobuf message to serialize.

    Returns:
        The same ``file`` object that was passed in.
    """
    length_prefix = _VarintBytes(message.ByteSize())
    file.write(length_prefix)
    file.write(message.SerializeToString())
    return file


def read_protobuf(file, protobuf_message_class):
    """Yield messages from a stream of varint-length-prefixed protobuf records.

    This reads the framing produced by ``write_protobuf``. Each record is its
    byte length encoded as a varint, followed by that many serialized bytes.

    Args:
        file: a path (``str`` or ``os.PathLike``) to open in binary mode, or an
            already-open binary file object. The whole content is read at once.
            Paths are opened and closed here. File objects are left open for
            the caller to close.
        protobuf_message_class: the message class to instantiate and parse
            each record into.

    Yields:
        One ``protobuf_message_class`` instance per record, in file order.
    """
    if isinstance(file, (str, os.PathLike)):
        # Read everything and close right away, so the handle is not leaked
        # for the lifetime of this generator.
        with open(file, "rb") as f:
            buf = f.read()
    else:
        buf = file.read()

    pos = 0
    while pos < len(buf):
        # _DecodeVarint32 returns (decoded value, position just past the varint).
        msg_len, pos = _DecodeVarint32(buf, pos)
        msg_buf = buf[pos : pos + msg_len]
        pos += msg_len

        message = protobuf_message_class()
        message.ParseFromString(msg_buf)

        yield message


class RobotStateVisualizer:
    """Collect robot states for plotting, streamed live or replayed from a logfile.

    Live mode (``logfile`` empty):
        A daemon thread reads the server's robot-state stream and enqueues
        every ``downsampling_ratio``-th state. Each state consumed by
        ``process_queue`` is also appended to ``./logfile.bin``.

    Replay mode (``logfile`` given):
        Every state in that file is enqueued up front.
    """

    def __init__(
        self,
        server_ip="localhost",
        server_port=50051,
        max_queue_size=10000,
        downsampling_ratio=100,
        log_keys=frozenset(
            {"joint_positions", "joint_velocities", "joint_torques_computed"}
        ),
        logfile="",
    ):
        """
        Args:
            server_ip: host of the Polymetis gRPC server (live mode only).
            server_port: port of the Polymetis gRPC server (live mode only).
            max_queue_size: capacity of the state queue in live mode. The
                streaming thread blocks on ``put`` while the queue is full.
            downsampling_ratio: in live mode, only every N-th streamed state
                is enqueued.
            log_keys: names of RobotState fields to collect into dataframes.
            logfile: path of a length-prefixed RobotState log to replay. When
                empty, data is streamed live from the server instead.
        """
        # Per-instance copy: the default is never shared or mutated.
        self.log_keys = set(log_keys)
        # Pre-set the resources __del__ releases, so cleanup also works when
        # construction fails part-way through.
        self.channel = None
        self.logfile = None
        if logfile:
            self.stream_live_data = False
            print(f"Reading data from {logfile}")
            self.state_queue = queue.Queue()
            for i, robot_state in enumerate(
                read_protobuf(logfile, polymetis_pb2.RobotState)
            ):
                self.state_queue.put((i, robot_state))
        else:
            server_connection = f"{server_ip}:{server_port}"
            print(f"Streaming data from server {server_connection}")
            self.stream_live_data = True

            # Setup logfile to write robot_states
            logfile_path = os.path.join(os.getcwd(), "logfile.bin")
            self.logfile = open(logfile_path, "wb")
            print(f"Saving logfile to {logfile_path}")

            # Set up connection
            self.channel = grpc.insecure_channel(server_connection)
            self.grpc_connection = polymetis_pb2_grpc.PolymetisControllerServerStub(
                self.channel
            )
            # Connect to RPC
            self.stream = self.grpc_connection.GetRobotStateStream(
                polymetis_pb2.Empty()
            )

            # update() reads these as soon as the thread starts, so they must
            # exist before start() is called.
            self.step = 0
            self.downsampling_ratio = downsampling_ratio

            # Concurrently read from stream in a different thread
            self.state_queue = queue.Queue(maxsize=max_queue_size)
            self.streaming_thread = threading.Thread(
                target=self.update, args=(), daemon=True
            )
            self.streaming_thread.start()

    def __del__(self):
        """Close the gRPC channel and the output logfile, if they were opened."""
        # getattr: __del__ also runs on partially-constructed instances.
        channel = getattr(self, "channel", None)
        if channel is not None:
            channel.close()
        logfile = getattr(self, "logfile", None)
        if logfile is not None:
            logfile.close()

    def update(self):
        """Streaming-thread body: enqueue every ``downsampling_ratio``-th streamed state."""
        if self.stream_live_data:
            for robot_state in self.stream:
                if self.step % self.downsampling_ratio == 0:
                    # Blocks while the queue is full.
                    self.state_queue.put((self.step, robot_state))
                self.step += 1

    def process_queue(self):
        """Drain the state queue into one DataFrame per logged field.

        In live mode, every drained state is also appended to the output
        logfile, which is flushed once the queue is empty.

        Returns:
            A dict mapping field name to a DataFrame. Keys appear in the order
            the fields were first seen. Each DataFrame:

            - is indexed by state timestamp (index name ``"datetime"``);
            - has one integer-labelled column per element of the field, so a
              scalar field gets a single column ``0``.

            The dict is empty if the queue held no states.
        """
        timestamps = {}
        rows = {}
        while True:
            try:
                _step, robot_state = self.state_queue.get_nowait()
            except queue.Empty:
                break
            if self.stream_live_data:
                write_protobuf(self.logfile, robot_state)
            curr_datetime = robot_state.timestamp.ToDatetime()
            for field, values in robot_state.ListFields():
                if field.name not in self.log_keys:
                    continue
                try:
                    row = list(values)
                except TypeError:
                    # Non-iterable (scalar) field: plot it as a single column.
                    row = [values]
                timestamps.setdefault(field.name, []).append(curr_datetime)
                rows.setdefault(field.name, []).append(row)

        if self.stream_live_data and self.logfile is not None:
            # Keep the on-disk log current even if the process exits abruptly.
            self.logfile.flush()

        # Build each DataFrame once, rather than growing it row by row.
        return {
            name: pd.DataFrame(
                field_rows, index=pd.Index(timestamps[name], name="datetime")
            )
            for name, field_rows in rows.items()
        }


def initialize_graphs(viz, height=1000):
    """Wait for the first robot states and build the initial plotly figure.

    The figure has one subplot row per logged field and one subplot column per
    element of the widest field. Each element is drawn as a black line against
    the ``%H:%M:%S``-formatted state timestamps. Traces are added field by
    field, then column by column, so trace indices follow that order.

    Args:
        viz: a ``RobotStateVisualizer``, or any object exposing
            ``state_queue``, ``process_queue()``, ``stream_live_data`` and
            ``log_keys``.
        height: figure height in pixels.

    Returns:
        The plotly figure.

    Raises:
        ValueError: if a replayed logfile yielded no states, or if none of the
            received states contain any of ``viz.log_keys``.
    """
    while viz.state_queue.empty():
        if not getattr(viz, "stream_live_data", True):
            # Replay mode enqueues everything up front; nothing more will arrive.
            raise ValueError("No robot states to plot: the logfile is empty.")
        print("Waiting for states...")
        time.sleep(1)

    dataframes = viz.process_queue()
    if not dataframes:
        raise ValueError(
            "None of the received robot states contain any of the requested "
            f"fields {sorted(viz.log_keys)}; check `log_keys`."
        )

    num_rows = len(dataframes)
    num_cols = max(len(df.columns) for df in dataframes.values())

    # make_subplots assigns titles row by row, so pad each row to num_cols.
    subplot_titles = []
    for df_name, df in dataframes.items():
        row_titles = [f"{df_name}: {column}" for column in df.columns]
        subplot_titles.extend(row_titles + [""] * (num_cols - len(row_titles)))

    fig = plotly.subplots.make_subplots(
        rows=num_rows, cols=num_cols, subplot_titles=subplot_titles
    )
    fig.update_layout(height=height, showlegend=False)

    # Add graph objects
    for row, df in enumerate(dataframes.values(), start=1):
        datetimes = [x.strftime("%H:%M:%S") for x in df.index]
        for col, column in enumerate(df.columns, start=1):
            fig.add_trace(
                plotly.graph_objects.Scatter(
                    x=datetimes, y=df[column], line=dict(color="black")
                ),
                row=row,
                col=col,
            )

    return fig


@hydra.main(config_path="../polymetis/conf/", config_name="viz")
def main(cfg):
    """Launch a Dash app that periodically extends the robot-state plots.

    Config keys read:
        server_ip, server_port, max_queue_size, downsampling_ratio,
        log_keys, logfile, update_interval_seconds,
        app.external_stylesheets, app.debug, app.host, app.port.
    """
    print(f"Config:\n{omegaconf.OmegaConf.to_yaml(cfg)}")

    # Connect to server...
    visualizer = RobotStateVisualizer(
        server_ip=cfg.server_ip,
        server_port=cfg.server_port,
        max_queue_size=cfg.max_queue_size,
        downsampling_ratio=cfg.downsampling_ratio,
        log_keys=set(cfg.log_keys),
        logfile=cfg.logfile,
    )

    # Create initial figure
    fig = initialize_graphs(visualizer)

    # Create app
    style = omegaconf.OmegaConf.to_container(cfg.app.external_stylesheets, resolve=True)
    app = dash.Dash("controller_manager_visualizer", external_stylesheets=style)

    # Define app layout
    app.layout = html.Div(
        children=[
            html.H1(children="Controller manager visualization"),
            dcc.Graph(id="live-update-graph", figure=fig),
            dcc.Interval(
                id="interval-component",
                interval=cfg.update_interval_seconds * 1000,
                n_intervals=0,
            ),
        ]
    )

    @app.callback(
        Output("live-update-graph", "extendData"),
        [Input("interval-component", "n_intervals")],
    )
    def update_graph_data(n_intervals):
        """Append the states queued since the previous tick to the existing traces."""
        dataframes = visualizer.process_queue()
        if not dataframes:
            # Nothing new arrived (e.g. a replayed logfile is exhausted):
            # leave the graph untouched instead of sending an empty update.
            return dash.no_update

        xs = []
        ys = []
        for df in dataframes.values():
            datetimes = [x.strftime("%H:%M:%S") for x in df.index]
            for column in df.columns:
                xs.append(datetimes)
                ys.append(df[column].tolist())

        # initialize_graphs created its traces in this same field/column order.
        trace_indices = list(range(len(xs)))
        return [{"x": xs, "y": ys}, trace_indices]

    app.logger.setLevel(logging.WARNING)
    app.run_server(debug=cfg.app.debug, host=cfg.app.host, port=cfg.app.port)


if __name__ == "__main__":
    # hydra.main supplies `cfg`, so the entry point is called without arguments.
    main()
